{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1l8bWGmIJuQa"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "CPSnXS88KFEo"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "89xNCIO5hiCj"
      },
      "source": [
        "# Save and load a model using a distribution strategy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9Ejs4QVxIdAm"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/distribute/save_and_load\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/distribute/save_and_load.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/save_and_load.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/distribute/save_and_load.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A0lG6qgThxAS"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This tutorial demonstrates how you can save and load models in the SavedModel format with `tf.distribute.Strategy` during or after training. There are two kinds of APIs for saving and loading a Keras model: high-level (`tf.keras.Model.save` and `tf.keras.models.load_model`) and low-level (`tf.saved_model.save` and `tf.saved_model.load`).\n",
        "\n",
        "To learn about SavedModel and serialization in general, read the [SavedModel guide](../../guide/saved_model.ipynb) and the [Keras model serialization guide](https://www.tensorflow.org/guide/keras/save_and_serialize). Let's start with a simple example.\n",
        "\n",
        "Caution: TensorFlow models are code and it is important to be careful with untrusted code. Learn more in [Using TensorFlow securely](https://github.com/tensorflow/tensorflow/blob/master/SECURITY.md).\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FITHltVKQ4eZ"
      },
      "source": [
        "Import dependencies:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RWG5HchAiOrZ"
      },
      "outputs": [],
      "source": [
        "import tensorflow_datasets as tfds\n",
        "\n",
        "import tensorflow as tf\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qqapWj98ptNV"
      },
      "source": [
        "Load and prepare the data with TensorFlow Datasets and `tf.data`, and create the model using `tf.distribute.MirroredStrategy`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yrYiAf_ziRyw"
      },
      "outputs": [],
      "source": [
        "mirrored_strategy = tf.distribute.MirroredStrategy()\n",
        "\n",
        "def get_data():\n",
        "  datasets = tfds.load(name='mnist', as_supervised=True)\n",
        "  mnist_train, mnist_test = datasets['train'], datasets['test']\n",
        "\n",
        "  BUFFER_SIZE = 10000\n",
        "\n",
        "  # Scale the global batch size with the number of replicas.\n",
        "  BATCH_SIZE_PER_REPLICA = 64\n",
        "  BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync\n",
        "\n",
        "  def scale(image, label):\n",
        "    # Convert pixel values from integers in [0, 255] to floats in [0, 1].\n",
        "    image = tf.cast(image, tf.float32)\n",
        "    image /= 255\n",
        "\n",
        "    return image, label\n",
        "\n",
        "  train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)\n",
        "  eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)\n",
        "\n",
        "  return train_dataset, eval_dataset\n",
        "\n",
        "def get_model():\n",
        "  with mirrored_strategy.scope():\n",
        "    model = tf.keras.Sequential([\n",
        "        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),\n",
        "        tf.keras.layers.MaxPooling2D(),\n",
        "        tf.keras.layers.Flatten(),\n",
        "        tf.keras.layers.Dense(64, activation='relu'),\n",
        "        tf.keras.layers.Dense(10)\n",
        "    ])\n",
        "\n",
        "    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "                  optimizer=tf.keras.optimizers.Adam(),\n",
        "                  metrics=[tf.metrics.SparseCategoricalAccuracy()])\n",
        "    return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qmU4Y3feS9Na"
      },
      "source": [
        "Train the model with `tf.keras.Model.fit`: "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zmGurbJmS_vN"
      },
      "outputs": [],
      "source": [
        "model = get_model()\n",
        "train_dataset, eval_dataset = get_data()\n",
        "model.fit(train_dataset, epochs=2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "L01wjgvRizHS"
      },
      "source": [
        "## Save and load the model\n",
        "\n",
        "Now that you have a simple model to work with, let's explore the saving/loading APIs. \n",
        "There are two kinds of APIs available:\n",
        "\n",
        "*   High-level (Keras): `Model.save` and `tf.keras.models.load_model` (`.keras` zip archive format)\n",
        "*   Low-level: `tf.saved_model.save` and `tf.saved_model.load` (TF SavedModel format)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FX_IF2F1tvFs"
      },
      "source": [
        "### The Keras API"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O8xfceg4Z3H_"
      },
      "source": [
        "Here is an example of saving and loading a model with the Keras API:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LYOStjV5knTQ"
      },
      "outputs": [],
      "source": [
        "keras_model_path = '/tmp/keras_save.keras'\n",
        "model.save(keras_model_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yvQIdQp3zNMp"
      },
      "source": [
        "Restore the model without `tf.distribute.Strategy`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WrXAAVtrzRgv"
      },
      "outputs": [],
      "source": [
        "restored_keras_model = tf.keras.models.load_model(keras_model_path)\n",
        "restored_keras_model.fit(train_dataset, epochs=2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gYAnskzorda-"
      },
      "source": [
        "After restoring the model, you can continue training it without calling `Model.compile` again, since it was already compiled before saving. The model is saved in the Keras zip archive format, marked by the `.keras` extension. For more information, refer to [the guide on Keras saving](https://www.tensorflow.org/guide/keras/save_and_serialize).\n",
        "\n",
        "Now, restore the model and train it using a `tf.distribute.Strategy`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wROPrJaAqBQz"
      },
      "outputs": [],
      "source": [
        "another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')\n",
        "with another_strategy.scope():\n",
        "  restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)\n",
        "  restored_keras_model_ds.fit(train_dataset, epochs=2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PdiiPmL5tQk5"
      },
      "source": [
        "As the `Model.fit` output shows, loading works as expected with `tf.distribute.Strategy`. The strategy used here does not have to be the same strategy used before saving. "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3CrXIbmFt0f6"
      },
      "source": [
        "### The `tf.saved_model` API"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HtGzPp6et4Em"
      },
      "source": [
        "Saving the model with the lower-level API is similar to the Keras API:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4y6T31APuCqK"
      },
      "outputs": [],
      "source": [
        "model = get_model()  # get a fresh model\n",
        "saved_model_path = '/tmp/tf_save'\n",
        "tf.saved_model.save(model, saved_model_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q1QNRYcwuRll"
      },
      "source": [
        "Loading can be done with `tf.saved_model.load`. However, since it is a lower-level API (and hence has a wider range of use cases), it does not return a Keras model. Instead, it returns an object that contains functions that can be used to do inference. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aaEKqBSPwAuM"
      },
      "outputs": [],
      "source": [
        "DEFAULT_FUNCTION_KEY = 'serving_default'\n",
        "loaded = tf.saved_model.load(saved_model_path)\n",
        "inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x65l7AaHUZCA"
      },
      "source": [
        "The loaded object may contain multiple functions, each associated with a key. The `\"serving_default\"` key is the default key for the inference function with a saved Keras model. To do inference with this function: "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5Ore5q8-UjW1"
      },
      "outputs": [],
      "source": [
        "predict_dataset = eval_dataset.map(lambda image, label: image)\n",
        "for batch in predict_dataset.take(1):\n",
        "  print(inference_func(batch))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "osB1LY8WwUJZ"
      },
      "source": [
        "You can also load and do inference in a distributed manner:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iDYvu12zYTmT"
      },
      "outputs": [],
      "source": [
        "another_strategy = tf.distribute.MirroredStrategy()\n",
        "with another_strategy.scope():\n",
        "  loaded = tf.saved_model.load(saved_model_path)\n",
        "  inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]\n",
        "\n",
        "  dist_predict_dataset = another_strategy.experimental_distribute_dataset(\n",
        "      predict_dataset)\n",
        "\n",
        "  # Calling the function in a distributed manner\n",
        "  for batch in dist_predict_dataset:\n",
        "    result = another_strategy.run(inference_func, args=(batch,))\n",
        "    print(result)\n",
        "    break"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hWGSukoyw3fF"
      },
      "source": [
        "Calling the restored function is just a forward pass on the saved model (`tf.keras.Model.predict`). What if you want to continue training the loaded function? Or what if you need to embed the loaded function into a bigger model? A common practice is to wrap this loaded object into a Keras layer to achieve this. Luckily, [TF Hub](https://www.tensorflow.org/hub) has [`hub.KerasLayer`](https://github.com/tensorflow/hub/blob/master/tensorflow_hub/keras_layer.py) for this purpose, shown here:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "clfk3hQoyKu6"
      },
      "outputs": [],
      "source": [
        "import tensorflow_hub as hub\n",
        "\n",
        "def build_model(loaded):\n",
        "  x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')\n",
        "  # Wrap what's loaded to a KerasLayer\n",
        "  keras_layer = hub.KerasLayer(loaded, trainable=True)(x)\n",
        "  model = tf.keras.Model(x, keras_layer)\n",
        "  return model\n",
        "\n",
        "another_strategy = tf.distribute.MirroredStrategy()\n",
        "with another_strategy.scope():\n",
        "  loaded = tf.saved_model.load(saved_model_path)\n",
        "  model = build_model(loaded)\n",
        "\n",
        "  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "                optimizer=tf.keras.optimizers.Adam(),\n",
        "                metrics=[tf.metrics.SparseCategoricalAccuracy()])\n",
        "  model.fit(train_dataset, epochs=2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Oe1z_OtSJlu2"
      },
      "source": [
        "In the above example, TensorFlow Hub's `hub.KerasLayer` wraps the result loaded back from `tf.saved_model.load` into a Keras layer that is used to build another model. This is very useful for transfer learning."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KFDOZpK5Wa3W"
      },
      "source": [
        "### Which API should I use?"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GC6GQ9HDLxD6"
      },
      "source": [
        "For saving, if you are working with a Keras model, use the Keras `Model.save` API unless you need the additional control allowed by the low-level API. If what you are saving is not a Keras model, then the lower-level API, `tf.saved_model.save`, is your only choice. \n",
        "\n",
        "For loading, your API choice depends on what you want to get from the model loading API. If you cannot (or do not want to) get a Keras model, then use `tf.saved_model.load`. Otherwise, use `tf.keras.models.load_model`. Note that you can get a Keras model back only if you saved a Keras model. \n",
        "\n",
        "It is possible to mix and match the APIs. You can save a Keras model with `Model.save` and then load it as a non-Keras object with the low-level API, `tf.saved_model.load`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ktwg2GwnXE8v"
      },
      "outputs": [],
      "source": [
        "model = get_model()\n",
        "\n",
        "# Saving the model using Keras `Model.save`\n",
        "model.save(saved_model_path)\n",
        "\n",
        "another_strategy = tf.distribute.MirroredStrategy()\n",
        "# Loading the model using the lower-level API\n",
        "with another_strategy.scope():\n",
        "  loaded = tf.saved_model.load(saved_model_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0Z7lSj8nZiW5"
      },
      "source": [
        "### Saving/Loading from a local device"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NVAjWcosZodw"
      },
      "source": [
        "When saving and loading from a local I/O device while training on remote devices—for example, when using a Cloud TPU—you must use the option `experimental_io_device` in `tf.saved_model.SaveOptions` and `tf.saved_model.LoadOptions` to set the I/O device to `localhost`. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jFcuzsI94bNA"
      },
      "outputs": [],
      "source": [
        "model = get_model()\n",
        "\n",
        "# Saving the model to a path on localhost.\n",
        "saved_model_path = '/tmp/tf_save'\n",
        "save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')\n",
        "model.save(saved_model_path, options=save_options)\n",
        "\n",
        "# Loading the model from a path on localhost.\n",
        "another_strategy = tf.distribute.MirroredStrategy()\n",
        "with another_strategy.scope():\n",
        "  load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')\n",
        "  loaded = tf.keras.models.load_model(saved_model_path, options=load_options)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hJTWOnC9iuA3"
      },
      "source": [
        "### Caveats"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2cCSZrD7VCxe"
      },
      "source": [
        "One special case arises when you create a Keras model without a defined input shape, such as a subclassed `tf.keras.Model`, and save it before training:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gurSIbDFjOBc"
      },
      "outputs": [],
      "source": [
        "class SubclassedModel(tf.keras.Model):\n",
        "  \"\"\"Example model defined by subclassing `tf.keras.Model`.\"\"\"\n",
        "\n",
        "  output_name = 'output_layer'\n",
        "\n",
        "  def __init__(self):\n",
        "    super().__init__()\n",
        "    self._dense_layer = tf.keras.layers.Dense(\n",
        "        5, dtype=tf.dtypes.float32, name=self.output_name)\n",
        "\n",
        "  def call(self, inputs):\n",
        "    return self._dense_layer(inputs)\n",
        "\n",
        "my_model = SubclassedModel()\n",
        "try:\n",
        "  my_model.save(saved_model_path)\n",
        "except ValueError as e:\n",
        "  print(f'{type(e).__name__}: ', *e.args)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D4qMyXFDSPDO"
      },
      "source": [
        "A SavedModel saves the `tf.types.experimental.ConcreteFunction` objects generated when you trace a `tf.function` (check _When is a Function tracing?_ in the [Introduction to graphs and tf.function](../../guide/intro_to_graphs.ipynb) guide to learn more). If you get a `ValueError` like this, it is because `Model.save` was not able to find or create a traced `ConcreteFunction`.\n",
        "\n",
        "**Caution:** You shouldn't save a model without at least one `ConcreteFunction`, since the low-level API will otherwise generate a SavedModel with no `ConcreteFunction` signatures ([learn more](../../guide/saved_model.ipynb) about the SavedModel format). For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "064SE47mYDj8"
      },
      "outputs": [],
      "source": [
        "tf.saved_model.save(my_model, saved_model_path)\n",
        "x = tf.saved_model.load(saved_model_path)\n",
        "x.signatures"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LRTxlASJX-cY"
      },
      "source": [
        "\n",
        "Usually the model's forward pass—the `call` method—will be traced automatically when the model is called for the first time, often via the Keras `Model.fit` method. A `ConcreteFunction` can also be generated by the Keras [Sequential](https://www.tensorflow.org/guide/keras/sequential_model) and [Functional](https://www.tensorflow.org/guide/keras/functional) APIs, if you set the input shape, for example, by making the first layer either a `tf.keras.layers.InputLayer` or another layer type, and passing it the `input_shape` keyword argument. \n",
        "\n",
        "To verify whether your model has any traced `ConcreteFunction`s, check whether `Model.save_spec()` returns `None`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xAXise4eR0YJ"
      },
      "outputs": [],
      "source": [
        "print(my_model.save_spec() is None)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G2G_FQrWJAO3"
      },
      "source": [
        "Train the model with `tf.keras.Model.fit`. After training, `save_spec` is defined and saving the model works:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cv5LTi0zDkKS"
      },
      "outputs": [],
      "source": [
        "BATCH_SIZE_PER_REPLICA = 4\n",
        "BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync\n",
        "\n",
        "dataset_size = 100\n",
        "dataset = tf.data.Dataset.from_tensors(\n",
        "    (tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))\n",
        "    ).repeat(dataset_size).batch(BATCH_SIZE)\n",
        "\n",
        "my_model.compile(optimizer='adam', loss='mean_squared_error')\n",
        "my_model.fit(dataset, epochs=2)\n",
        "\n",
        "print(my_model.save_spec() is None)\n",
        "my_model.save(saved_model_path)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "save_and_load.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
